package com.prodog.database.wrapper;

import com.mongodb.client.result.DeleteResult;
import com.prodog.utils.bean.BeanUtil;
import lombok.Data;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.data.mongodb.core.query.Criteria;
import org.springframework.data.mongodb.core.query.Query;
import org.springframework.stereotype.Component;

;
import java.lang.reflect.Field;
import java.util.Collection;
import java.util.List;
import java.util.function.Predicate;
import java.util.stream.Collectors;

/**
 * {@link AbstractDataWrapper} implementation that stores and queries entities through a
 * Spring Data {@link MongoTemplate}.
 *
 * <p>The entity class used for every template call is whatever {@code getTypeClass()} returns.
 * Write operations report failure through their {@code boolean} return value instead of
 * throwing: exceptions raised by the template are printed and turned into {@code false}.
 *
 * @param <T> entity type
 * @param <P> type of the entity's identifier
 */
@Component
@Data
public class MongoDataWrapper<T, P> extends AbstractDataWrapper<T, P> {
    /** Template used for all MongoDB access; assigned by {@link #initTemplate(MongoTemplate)}. */
    public MongoTemplate template;

    public MongoDataWrapper() {
    }

    /**
     * Injection point for the {@link MongoTemplate}.
     *
     * @param template the template to use for all subsequent operations
     */
    @Autowired
    public void initTemplate(MongoTemplate template) {
        this.template = template;
    }

    /**
     * Reads a field from {@code data} reflectively.
     *
     * <p>The lookup starts at the class returned by {@code getTypeClass()} and walks up the
     * superclass chain, so fields declared on a parent class are found as well.
     *
     * @param data the object to read from
     * @param name the Java field name
     * @return the field value, or {@code null} if the field does not exist or cannot be read
     */
    private Object getFieldVal(T data, String name) {
        for (Class<?> type = getTypeClass(); type != null; type = type.getSuperclass()) {
            try {
                Field field = type.getDeclaredField(name);
                field.setAccessible(true);
                return field.get(data);
            } catch (NoSuchFieldException ignored) {
                // Not declared on this class; keep looking in the superclass.
            } catch (Exception e) {
                e.printStackTrace();
                return null;
            }
        }
        // Keep the diagnostic the single-class lookup used to print for unknown fields.
        new NoSuchFieldException(name).printStackTrace();
        return null;
    }

    /**
     * Builds a query matching documents whose {@code column} equals {@code val}.
     */
    private Query columnQuery(String column, Object val) {
        return new Query(Criteria.where(column).is(val));
    }

    /**
     * Builds a query from alternating column-name / value pairs.
     *
     * @param items {@code column1, value1, column2, value2, ...}
     * @return a query with one equality criterion per pair
     * @throws IllegalArgumentException if {@code items} is {@code null}, has an odd length,
     *                                  or a column name is not a {@link String}
     */
    private Query columnsQuery(Object... items) {
        if (items == null || items.length % 2 != 0) {
            throw new IllegalArgumentException(
                    "items must be alternating column/value pairs (even number of arguments)");
        }
        Query query = new Query();
        for (int i = 0; i < items.length; i += 2) {
            if (!(items[i] instanceof String)) {
                throw new IllegalArgumentException(
                        "column name at index " + i + " must be a String but was: " + items[i]);
            }
            query.addCriteria(Criteria.where((String) items[i]).is(items[i + 1]));
        }
        return query;
    }

    /**
     * Inserts {@code obj} as a new document.
     *
     * @return {@code true} on success, {@code false} if the template threw
     */
    @Override
    public boolean insert(T obj) {
        try {
            template.insert(BeanUtil.bean2Bean(obj, getTypeClass()));
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * Inserts {@code obj}; {@code fileName} is not used by this implementation.
     *
     * @return the result of {@link #insert(Object)}
     */
    @Override
    public boolean insert(T obj, String fileName) {
        // insert(obj) already converts failures into false; propagate that instead of
        // unconditionally reporting success.
        return insert(obj);
    }

    /**
     * Saves {@code obj} only if a document with the same {@code id} field already exists.
     *
     * @return {@code true} if the document existed and was saved; {@code false} if the id is
     *         {@code null}, no such document exists, or the save failed
     */
    @Override
    public boolean update(T obj) {
        @SuppressWarnings("unchecked") // the "id" field is expected to hold a P
        P objId = (P) getFieldVal(obj, "id");
        if (objId == null || getById(objId) == null) {
            return false;
        }
        return save(obj);
    }

    /**
     * Saves {@code obj} through {@link MongoTemplate#save(Object)}.
     *
     * @return {@code true} on success, {@code false} if the template threw
     */
    @Override
    public boolean save(T obj) {
        try {
            template.save(BeanUtil.bean2Bean(obj, getTypeClass()));
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * @return the document with the given id, as returned by
     *         {@link MongoTemplate#findById(Object, Class)}
     */
    @Override
    public T getById(P id) {
        return template.findById(id, getTypeClass());
    }

    /**
     * @return all documents whose {@code _id} is contained in {@code ids}
     */
    @Override
    public List<T> getByIds(Collection<P> ids) {
        // Must be in(...), not is(...): is(ids) would compare _id against the collection itself.
        Query query = new Query(Criteria.where("_id").in(ids));
        return template.find(query, getTypeClass());
    }

    /**
     * @return {@code true} if at least one document was deleted
     */
    @Override
    public boolean removeById(P id) {
        Query query = new Query(Criteria.where("_id").is(id));
        DeleteResult remove = template.remove(query, getTypeClass());
        return remove.getDeletedCount() > 0;
    }

    /**
     * @return {@code true} if at least one document was deleted
     */
    @Override
    public boolean removeByIds(Collection<P> ids) {
        Query query = new Query(Criteria.where("_id").in(ids));
        DeleteResult remove = template.remove(query, getTypeClass());
        return remove.getDeletedCount() > 0;
    }

    /**
     * @return every document of the entity type
     */
    @Override
    public List<T> list() {
        return template.findAll(getTypeClass());
    }

    /**
     * Loads every document and filters in memory.
     *
     * <p>Original author's note: there are not many users, so this is good enough.
     *
     * @param predicate filter applied to each loaded entity
     * @return the entities accepted by {@code predicate}
     */
    @Override
    public List<T> list(Predicate<T> predicate) {
        return list().stream().filter(predicate).collect(Collectors.toList());
    }

    /**
     * @return one document whose {@code column} equals {@code val}, as returned by
     *         {@link MongoTemplate#findOne(Query, Class)}
     */
    @Override
    public T getByColumn(String column, Object val) {
        return template.findOne(columnQuery(column, val), getTypeClass());
    }

    /**
     * @return all documents whose {@code column} equals {@code val}
     */
    @Override
    public List<T> listByColumn(String column, Object val) {
        return template.find(columnQuery(column, val), getTypeClass());
    }

    /**
     * @return the number of documents whose {@code column} equals {@code val}
     */
    @Override
    public long countByColumn(String column, Object val) {
        return template.count(columnQuery(column, val), getTypeClass());
    }

    /**
     * Computes the largest value of a numeric field by loading every document and reading the
     * field reflectively.
     *
     * @param column Java field name of a numeric field
     * @return the largest value found, or {@code 0} if there are no documents or no value
     *         is greater than {@code 0}; {@code null} values are skipped
     * @throws ClassCastException if a non-null value is not a {@link Number}
     */
    @Override
    public long max(String column) {
        long max = 0;
        for (T data : list()) {
            Object val = getFieldVal(data, column);
            if (val == null) {
                continue;
            }
            if (!(val instanceof Number)) {
                throw new ClassCastException(
                        "field '" + column + "' is not numeric: " + val.getClass().getName());
            }
            // longValue() also accepts Integer/Short/etc., which a direct (long) cast rejects.
            long curr = ((Number) val).longValue();
            if (curr > max) {
                max = curr;
            }
        }
        return max;
    }

    /**
     * Finds one document matching all of the given column/value pairs.
     *
     * @param items {@code column1, value1, column2, value2, ...}
     * @throws IllegalArgumentException if {@code items} is not a list of column/value pairs
     */
    @Override
    public T getByColumns(Object... items) {
        return template.findOne(columnsQuery(items), getTypeClass());
    }

    /**
     * Finds all documents matching all of the given column/value pairs.
     *
     * @param items {@code column1, value1, column2, value2, ...}
     * @throws IllegalArgumentException if {@code items} is not a list of column/value pairs
     */
    @Override
    public List<T> listByColumns(Object... items) {
        return template.find(columnsQuery(items), getTypeClass());
    }
}
